import torch
import cv2
#from datasets.utils import resize_and_pad, resize
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
from collections import OrderedDict
from torch.nn import functional as F

class PromptLearner(nn.Module):
    """Learnable prompt embeddings that are shifted by an image-derived bias.

    Three fixed phrases are tokenized and passed through the CLIP token
    embedding; slices of the results seed the trainable parameters ``ctx``,
    ``obj_vectors`` and ``sub_vectors``.  At forward time ``meta_net`` maps
    image features to a bias that is added to all three parameter blocks
    before they are concatenated with caller-supplied prefix/suffix
    embeddings.

    Args:
        clip_model: CLIP model.  It is moved to CUDA and stored as
            ``self.clip_model``; its visual input resolution must be 224.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model.cuda()
        self.dtype = self.clip_model.dtype

        text_width = self.clip_model.ln_final.weight.shape[0]
        image_width = self.clip_model.visual.output_dim

        cfg_imsize = 224
        clip_imsize = self.clip_model.visual.input_resolution
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        def embed_phrase(text):
            # Token-embedding lookup only (no transformer pass), cast to the
            # CLIP dtype; done without gradient tracking.
            token_ids = clip.tokenize(text).cuda()
            with torch.no_grad():
                return self.clip_model.token_embedding(token_ids).type(self.dtype)

        # Seed phrases used to initialise the learnable vectors.
        ctx_phrase = "considering the relationship of".replace("_", " ")
        obj_phrase = ", which is the main object to be segmented"
        sub_phrase = "these objects are backgrounds and clues, find clues from these backgrounds to locate the previously mentioned objects."
        n_ctx = len(ctx_phrase.split(" "))

        # Every slice skips position 0 (presumably the start-of-text token --
        # confirm against clip.tokenize).  The context window is one position
        # per space-separated word; the object and background windows are
        # fixed at 10 and 30 positions respectively.
        ctx_seed = embed_phrase(ctx_phrase)[0, 1:1 + n_ctx, :]
        obj_seed = embed_phrase(obj_phrase)[0, 1:1 + 10, :]
        sub_seed = embed_phrase(sub_phrase)[0, 1:1 + 30, :]

        self.ctx = nn.Parameter(ctx_seed)
        self.obj_vectors = nn.Parameter(obj_seed)
        self.sub_vectors = nn.Parameter(sub_seed)

        # Two-layer bottleneck MLP: image feature width -> text embedding width.
        meta_layers = [
            ("linear1", nn.Linear(image_width, image_width // 4)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(image_width // 4, text_width)),
        ]
        self.meta_net = nn.Sequential(OrderedDict(meta_layers))

        self.n_ctx = n_ctx

    def construct_prompts(self, ctx, obj_shifted, sub_shifted, obj_prefix, sub_suffix, label=None):
        """Concatenate the prompt pieces along dim 1.

        The output order is: ``obj_prefix``, ``obj_shifted``, ``ctx``,
        ``sub_suffix``, ``sub_shifted``.  All pieces must therefore agree on
        every dimension except dim 1.  ``label`` is accepted but not used.

        Returns:
            The concatenated prompt tensor.
        """
        pieces = (obj_prefix, obj_shifted, ctx, sub_suffix, sub_shifted)
        return torch.cat(pieces, dim=1)

    def forward(self, im_features, obj_prefix, sub_suffix):
        """Build image-conditioned prompts.

        Args:
            im_features: input to ``meta_net``; expected to yield a
                (batch, ctx_dim) bias per the original shape notes.
            obj_prefix: embeddings placed at the start of the prompt.
            sub_suffix: embeddings placed between the shifted context and the
                shifted background vectors.

        Returns:
            The tensor produced by :meth:`construct_prompts`.
        """
        # Insert a length-1 axis at dim 1 so the same per-image bias
        # broadcasts over every token position of each parameter block.
        image_bias = self.meta_net(im_features).unsqueeze(1)

        return self.construct_prompts(
            self.ctx.unsqueeze(0) + image_bias,
            self.obj_vectors + image_bias,
            self.sub_vectors + image_bias,
            obj_prefix,
            sub_suffix,
        )

class TextEncoder(nn.Module):
    """Run prompt embeddings through the text-side components of a CLIP model.

    The transformer, positional embedding, final layer norm, text projection
    and dtype are taken from ``clip_model`` by reference.  ``forward`` returns
    the layer-normalised sequence for every position; ``text_projection`` is
    stored but not applied.
    """

    # Attributes copied from the CLIP model, in registration order.
    _SHARED_ATTRS = (
        "transformer",
        "positional_embedding",
        "ln_final",
        "text_projection",
        "dtype",
    )

    def __init__(self, clip_model):
        super().__init__()
        for attr in self._SHARED_ATTRS:
            setattr(self, attr, getattr(clip_model, attr))

    def forward(self, prompts):
        """Encode ``prompts`` and return per-position features.

        Args:
            prompts: 3-D tensor, batch-first (N, L, D); it must be
                broadcast-compatible with the positional embedding.

        Returns:
            Tensor laid out as (N, L, D) after the final layer norm, cast to
            ``self.dtype``.
        """
        positioned = prompts + self.positional_embedding.type(self.dtype)
        # The transformer is fed sequence-first, so swap to LND and back.
        hidden = self.transformer(positioned.permute(1, 0, 2)).permute(1, 0, 2)
        return self.ln_final(hidden).type(self.dtype)


def _get_activation_fn(activation):
    """Map an activation name to its functional implementation.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu`` respectively.

    Raises:
        RuntimeError: if ``activation`` is neither of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu))
    for name, fn in supported:
        if activation == name:
            return fn
    raise RuntimeError("activation should be relu/gelu, not %s." % activation)

class CrossAttention(nn.Module):
    r"""Cross-attention layer followed by a position-wise feed-forward network.

    ``tgt`` attends over the concatenation of ``memory`` and ``tgt``; the
    result and the feed-forward output are each added residually and
    layer-normalised (post-norm).

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=1024).
        dropout: the dropout value (default=0.2).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Note:
        ``self_attn`` is constructed (and therefore appears in the state
        dict) but ``forward`` does not use it.
    """

    def __init__(self, d_model, nhead, dim_feedforward=1024, dropout=0.2, activation="relu"):
        super().__init__()

        # Attention modules; only ``multihead_attn`` is used by ``forward``.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward sub-layer: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One norm/dropout pair per residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, tgt, memory):
        r"""Apply cross-attention and the feed-forward network to ``tgt``.

        Args:
            tgt: query tensor; also appended to the keys/values.
            memory: tensor concatenated in front of ``tgt`` along dim 0 to
                form the keys/values.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # Keys/values are memory followed by the queries themselves.
        keys_values = torch.cat([memory, tgt], dim=0)
        attended = self.multihead_attn(tgt, keys_values, keys_values)[0]
        tgt = self.norm1(tgt + self.dropout1(attended))

        # Instances lacking an ``activation`` attribute fall back to ReLU
        # (backward compatibility).
        act = getattr(self, "activation", F.relu)
        feed_forward = self.linear2(self.dropout(act(self.linear1(tgt))))
        return self.norm2(tgt + self.dropout2(feed_forward))